# models

import os
import random
import shutil
import subprocess
import sys
import time
import traceback

import numpy as np
import torch
import torch.nn as nn
import torch.nn.modules.loss as Loss
from torch import optim


######################################
# Seed the torch (CPU and every CUDA device), Python `random` and NumPy
# generators with one fixed value so that runs can be repeated.
seed_num = 888

for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed_all, random.seed, np.random.seed):
    _seed_rng(seed_num)
del _seed_rng


####################################
#
# Loading Dataset
#
####################################

#==============================
# In PyTorch, the data tensor should be in the following format:
# (num_Cases, num_channels, img_height, img_width)
#
#=============================

# # Load Training Dataset
# data = loadmat('./T1Dataset_MOLLI5_Training_p1.mat')['data']
# tr_t1w_TI, tr_T1 = data['tr_t1w_TI'][0][0], data['tr_T1'][0][0]
# tr_T1_5 = data['tr_T1_5'][0][0]
#
# #Load Testing dataset and all masks for evaluation
# data = loadmat('./T1Dataset_MOLLI5_Testing.mat')['data']
# tst_t1w_TI, tst_T1 = data['tst_t1w_TI'][0][0], data['tst_T1'][0][0]
# tst_T1_5, tst_mask = data['tst_T1_5'][0][0], data['tst_mask'][0][0]
# tst_LVmask, tst_ROImask, tst_sliceID = data['tst_LVmask'][0][0], data['tst_ROImask'][0][0], data['tst_sliceID'][0][0]


####################################
#
# initializations
#
####################################

model_save_dir = './model_save/'      # checkpoints (MODEL_EPOCH<n>.pth) and the code snapshot go here
tensorboard_dir = './tensorboard_dir/'  # saved figures and the optional mse_Trial<n>.mat loss history
lr = 0.0001                 # SGD learning rate
Validation_Only = False     # True: skip training, load each epoch's checkpoint and only validate
batch_size = 80             # training mini-batch size
epochs = 2000               # last epoch to run (epochs are numbered from 1)
multi_GPU = True            # wrap the network in torch.nn.DataParallel when possible
net_scale = 2               # width multiplier for the hidden layers of Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_ids = range(0, torch.cuda.device_count())

trialNum = 1                # trial index used in the mse_Trial<n> loss-history filename
# exist_ok=True replaces the check-then-create pattern, which can race if the
# directory appears between the exists() test and makedirs().
os.makedirs(model_save_dir, exist_ok=True)
os.makedirs(tensorboard_dir, exist_ok=True)


####################################
#
# Create MyoMapNet Model
#
####################################

class Model(nn.Module):
    """MyoMapNet: a fully connected network regressing T1 from sampled signal points.

    Six ``nn.Linear`` layers with ``LeakyReLU`` between them, applied to the last
    dimension of the input. The hidden widths are 200, 200, 100, 100, 50, each
    multiplied by ``scale``.

    Args:
        in_channels: number of input features per sample.
        out_channels: number of outputs per sample.
        scale: hidden-width multiplier; defaults to the module-level ``net_scale``.
    """

    # Hidden-layer widths before scaling.
    _BASE_WIDTHS = (200, 200, 100, 100, 50)

    def __init__(self, in_channels, out_channels, scale=None):
        super().__init__()
        if scale is None:
            scale = net_scale
        widths = [in_channels] + [w * scale for w in self._BASE_WIDTHS]

        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features=n_in, out_features=n_out))
            layers.append(nn.LeakyReLU())
        layers.append(nn.Linear(in_features=widths[-1], out_features=out_channels))

        # Same attribute name and layer order as before, so the state-dict keys
        # (T1fitNet.0.weight ... T1fitNet.10.bias) of existing checkpoints still match.
        self.T1fitNet = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to the last dimension of *x*."""
        return self.T1fitNet(x)


# Native T1: MyoMapNet maps 10 inputs [5 T1-weighted signals, 5 inversion times]
# to a single output, the T1 value.
net = Model(in_channels=10, out_channels=1)

# Post-contrast T1 uses 8 inputs [4 T1-weighted signals, 4 inversion times]:
# net = Model(in_channels=8, out_channels=1)


def multiply_elems(x):
    """Return the product of the elements of *x* (e.g. a tensor shape); 1 when *x* is empty."""
    product = 1
    for factor in x:
        product = product * factor
    return product

# Report the model size: total number of scalars over all parameter tensors.
num_params = sum(multiply_elems(p.shape) for p in net.parameters())
print('Total number of parameters: {0}'.format(num_params))

# The last visible GPU is left out of DataParallel (the original `device_ids[:-1]`;
# presumably to keep it free for other work — confirm). With one GPU or none that
# list is empty, and DataParallel would raise (empty device list) or `.cuda()` would
# fail on a CPU-only machine. In that case place the plain model on `device`.
dp_device_ids = list(device_ids)[:-1]
if multi_GPU and dp_device_ids:
    net = torch.nn.DataParallel(net, device_ids=dp_device_ids).cuda()
else:
    net.to(device)


def train(net):
    """Train MyoMapNet and evaluate it on the test arrays.

    Workflow:
      1. Resume from a ``MODEL_EPOCH<n>.pth`` checkpoint in ``model_save_dir``
         (see ``s_epoch`` below; by default the newest one).
      2. Copy the current working directory into ``model_save_dir`` as a code snapshot.
      3. For each epoch run shuffled mini-batch SGD (MSE loss) on the training
         arrays. A checkpoint is saved every 50 iterations and at the end of
         the epoch. When ``Validation_Only`` is set, that epoch's checkpoint is
         loaded instead.
      4. After the final epoch (or after every epoch when ``Validation_Only``):
         - run the test arrays through the network;
         - save comparison figures with ``ntensorshow``;
         - save per-slice mean T1 values inside the analysis mask with
           ``saveArrayToMat``.

    Any exception raised inside an epoch is printed and the loop moves on to the
    next epoch.

    Relies on these module-level names:
      - data arrays: ``tr_t1w_TI``, ``tr_T1``, ``tr_T1_5``, ``tst_t1w_TI``,
        ``tst_T1``, ``tst_T1_5``, ``tst_mask``, ``tst_LVmask``, ``tst_ROImask``,
        ``tst_sliceID``;
      - settings: ``lr``, ``batch_size``, ``epochs``, ``Validation_Only``,
        ``device``, ``model_save_dir``, ``tensorboard_dir``, ``trialNum``;
      - helpers: ``save_checkpoint``, ``ntensorshow``, ``saveArrayToMat``.

    Args:
        net: the model (optionally wrapped in ``DataParallel``); trained in place.
    """
    ###########################################
    #
    # INITIALIZATIONS
    #
    ############################################
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.8)
    mseCriterion = Loss.MSELoss()

    LOSS = list()  # per-iteration training loss as plain floats; stored in every checkpoint
    i = 0          # global training-iteration counter; restored by load_model
    vi = 0         # global validation-batch counter (progress printing only)

    ckpt_prefix = 'MODEL_EPOCH'  # checkpoints are named MODEL_EPOCH<epoch>.pth

    def ckpt_path(epoch):
        """Return the checkpoint path for *epoch* (used for both saving and loading)."""
        return model_save_dir + ckpt_prefix + str(epoch) + '.pth'

    def checkpoint_state(epoch):
        """Bundle everything needed to resume training at *epoch*."""
        return {'epoch': epoch, 'loss': LOSS, 'arch': 'MyoMapNet_Model1',
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(), 'iteration': i,
                }

    def to_device(array):
        """Convert a host array to a float32 tensor on ``device``."""
        return torch.as_tensor(array, dtype=torch.float32).to(device)

    def image_shape(T1_5):
        """Target (batch, 1, H, W) shape for reshaping per-pixel predictions into images.

        NOTE(review): taken from T1_5, assuming it is (batch, 1, H, W) as ntensorshow
        indexes it; the original code referenced an undefined `xs` here — confirm.
        """
        return (T1_5.shape[0], 1, T1_5.shape[-2], T1_5.shape[-1])

    def masked_means(x, mask):
        """Return, per sample k, the mean of x[k] over the nonzero entries of mask[k] (NaN if none)."""
        means = []
        for k in range(x.shape[0]):
            values = x[k].detach().cpu().numpy()
            means.append(values[np.nonzero(mask[k].detach().cpu().numpy())].mean())
        return means

    def load_model(epoch):
        """Restore network, optimizer, loss history and iteration count from *epoch*'s checkpoint."""
        # nonlocal: without it the restored values were bound to throw-away locals.
        nonlocal LOSS, i
        print('loading model at epoch ' + str(epoch))
        checkpoint = torch.load(ckpt_path(epoch))
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        LOSS = list(checkpoint.get('loss', LOSS))
        i = checkpoint.get('iteration', i)
        # An exported loss history, when present, takes precedence over the checkpoint's.
        try:
            from scipy.io import loadmat
            mse = loadmat('{0}mse_Trial{1}'.format(tensorboard_dir, trialNum))['mse']
            LOSS = [float(v) for v in mse.ravel()]
        except (ImportError, OSError, KeyError, ValueError):
            pass

    ###########################################
    #
    # LOAD LATEST (or SPECIFIC) MODEL
    #
    ############################################
    def is_checkpoint(name):
        """True for file names of the form MODEL_EPOCH<digits>.pth."""
        return (name.startswith(ckpt_prefix) and name.endswith('.pth')
                and name[len(ckpt_prefix):-len('.pth')].isdigit())

    models = [m for m in os.listdir(model_save_dir) if is_checkpoint(m)]
    s_epoch = -1  ## -1: load the latest model, or start from 1 if there are no saved models
                  ##  0: don't load any model; start from model #1
                  ##  num: load model #num
    print(len(models))

    if s_epoch == -1:
        if not models:
            s_epoch = 1
        else:
            s_epoch = max(int(m[len(ckpt_prefix):-len('.pth')]) for m in models)
            load_model(s_epoch)
    elif s_epoch == 0:
        s_epoch = 1
    else:
        try:
            load_model(s_epoch)
        except (OSError, KeyError, RuntimeError):
            print('Model {0} does not exist!'.format(s_epoch))

    ## copy the code with the model saving directory
    # An argument list (no shell) keeps paths with spaces or shell metacharacters intact.
    try:
        subprocess.run(['cp', '-r', os.getcwd(), model_save_dir], check=False)
    except OSError:
        traceback.print_exc()
    print('Model copied!')

    tr_N = tr_T1.shape[0]
    tr_lst = list(range(0, tr_N))
    for epoch in range(s_epoch, epochs + 1):
        print('epoch {}/{}...'.format(epoch, epochs))
        random.shuffle(tr_lst)

        # adjust_learning_rate(optimizer, epoch)
        try:
            ###########################################
            #
            # Training
            #
            ############################################
            l = 0
            itt = 0
            is_best = 0  # defined up front so the end-of-epoch save works even with no batches
            if not Validation_Only:
                for idx in range(0, tr_N, batch_size):
                    try:
                        lst = tr_lst[idx:idx + batch_size]
                        X = to_device(tr_t1w_TI[lst, ])
                        y = to_device(tr_T1[lst, ])
                        T1_5 = to_device(tr_T1_5[lst, ])

                        y_pred = net(X)

                    except Exception:
                        traceback.print_exc()
                        continue

                    if i > 1 and i % 200 == 0:
                        # Periodic qualitative check; a plotting failure must not abort the epoch.
                        try:
                            shape = image_shape(T1_5)
                            ifnames = ['Image_epoch_{0}_iter_{1}_sl_{2}'.format(epoch, i, s) for s in range(0, batch_size)]
                            ntensorshow((T1_5, y_pred.reshape(shape), y.reshape(shape)), (0, 0), (0, 3),
                                        ('MOLLI-5', 'MyoMapNet', 'MOLLI-8'), saveFigs=True, figname=ifnames)
                        except Exception:
                            traceback.print_exc()

                    loss = mseCriterion(y_pred, y)
                    # .item() works on the 0-dim loss tensor (loss.data[0] raises on it).
                    loss_value = loss.item()
                    LOSS.append(loss_value)
                    l += loss_value

                    optimizer.zero_grad()
                    loss.backward()
                    i += 1
                    optimizer.step()

                    print('Epoch: {0} - {1:.3f}%'.format(epoch, 100 * (itt * batch_size) / tr_N)
                          + ' \tIter: ' + str(i)
                          + '\tLoss: {0:.6f}'.format(loss_value)
                          )
                    itt += 1
                    if i % 50 == 0:
                        save_checkpoint(checkpoint_state(epoch), is_best, filename=ckpt_path(epoch))

                avg_loss = batch_size * l / max(tr_N, 1)
                print('Total Loss : {0:.6f} \t Avg. Loss {1:.6f}'.format(l, avg_loss))

                save_checkpoint(checkpoint_state(epoch), is_best, filename=ckpt_path(epoch))
            else:
                load_model(epoch)

            #####################################
            #
            # Validation
            #
            #####################################

            # Validation runs only after the final epoch, or every epoch in validation-only mode.
            if epoch < epochs and not Validation_Only:
                continue

            bs = 48
            tst_N = tst_T1.shape[0]
            tst_lst = list(range(0, tst_N))
            mask_type = 'LVmask'  # one of 'LVmask', 'ROImask', 'bloodmask'
            analysis_masks = {'bloodmask': tst_mask, 'ROImask': tst_ROImask, 'LVmask': tst_LVmask}
            vitt = 0
            vld_mse = 0
            sl_id = list()
            pred_T1_5_avg = list()
            T1_5_avg = list()
            ref_T1_avg = list()

            with torch.no_grad():
                for idx in range(0, tst_N, bs):
                    try:
                        sel = tst_lst[idx:idx + bs]
                        X = to_device(tst_t1w_TI[sel])
                        y = to_device(tst_T1[sel])
                        T1_5 = to_device(tst_T1_5[sel])
                        analysis_msk = to_device(analysis_masks[mask_type][sel])
                        sliceID = tst_sliceID[sel].tolist()

                        y_pred = net(X)
                        mseloss = mseCriterion(y_pred, y)
                        vld_mse += mseloss.item()

                        shape = image_shape(T1_5)
                        pred_T1_5 = y_pred.reshape(shape)
                        ref_T1 = y.reshape(shape)

                        fignames = ['tst_ptn{0}_{1}'.format(s.split('/')[-1][:-4].split('_')[0],
                                                             s.split('/')[-1][:-4].split('_')[1]) for s in sliceID]
                        ntensorshow((T1_5, pred_T1_5, ref_T1), (0, 0), (0, 3),
                                    ('T1_fitting_5tw', 'Network', 'Ref_T1'), saveFigs=True, figname=fignames)

                        pred_means = masked_means(pred_T1_5, analysis_msk)
                        T1_5_means = masked_means(T1_5, analysis_msk)
                        ref_means = masked_means(ref_T1, analysis_msk)
                        # extend (not append) keeps the saved arrays flat, one entry per slice;
                        # a shorter final batch made the nested lists ragged for np.array.
                        sl_id.extend(sliceID)
                        pred_T1_5_avg.extend(pred_means)
                        T1_5_avg.extend(T1_5_means)
                        ref_T1_avg.extend(ref_means)
                        vi += 1
                        vitt += 1

                        print('Epoch: {0} - {1:.3f}%'.format(epoch, 100 * (vitt * bs) / tst_N)
                              + ' \tIter: ' + str(vi)
                              + '\tMSE: {0:.4f}'.format(mseloss.item()))
                    except Exception:
                        traceback.print_exc()
                        continue

                saveArrayToMat(sl_id, 'sl_id')
                saveArrayToMat(np.array(pred_T1_5_avg), 'pred_T1_5_avg')
                saveArrayToMat(np.array(ref_T1_avg), 'ref_T1_avg')
                saveArrayToMat(np.array(T1_5_avg), 'T1_5_avg')

                # Only MSE is reported: SSIM/PSNR were never computed anywhere.
                avg_factor = bs / max(tst_N, 1)
                print('Avg. MSE : {0:.6f}'.format(vld_mse * avg_factor))

        except Exception:
            # Best effort: report the failure and continue with the next epoch.
            traceback.print_exc()
            continue

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize *state* to *filename* with ``torch.save``.

    When *is_best* is truthy, the written file is also copied to
    ``model_best.pth.tar`` in the current working directory.
    """
    torch.save(state, filename)
    print('Model Saved!')
    if not is_best:
        return
    best_copy = 'model_best.pth.tar'
    shutil.copyfile(filename, best_copy)

def adjust_learning_rate(optimizer, epoch, base_lr=None, gamma=0.1, step=1000):
    """Set every param group's learning rate to ``base_lr * gamma ** (epoch // step)``.

    With the defaults this decays the module-level ``lr`` by a factor of 10 every
    1000 epochs.

    Args:
        optimizer: object with a ``param_groups`` list of dicts (e.g. a torch optimizer).
        epoch: current epoch number.
        base_lr: starting learning rate; defaults to the module-level ``lr``.
        gamma: multiplicative decay applied once per ``step`` epochs.
        step: number of epochs between decays.

    Returns:
        The learning rate that was applied.
    """
    if base_lr is None:
        base_lr = lr
    new_lr = base_lr * (gamma ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr

def tensorshow(x, sl_dims=(0, 0), rng=(0, 0.001)):
    """Display one 2-D magnitude slice of *x* as a grayscale image.

    Args:
        x: a NumPy array, either complex and indexed as ``x[a, b, :, :]``, or
           real with real/imaginary parts in the last axis (``x[a, b, :, :, 0|1]``).
           Anything else is passed through ``magnitude`` and indexed as
           ``[a, b, :, :]``.
        sl_dims: the two leading indices ``(a, b)`` selecting the slice.
        rng: ``(vmin, vmax)`` display window (a tuple default avoids a shared
             mutable default; any indexable pair works).
    """
    plt.figure()

    if isinstance(x, np.ndarray):
        if np.iscomplexobj(x):
            # Slice first, then take |.|, instead of computing the magnitude of the whole array.
            img = np.abs(x[sl_dims[0], sl_dims[1], :, :])
        else:
            # Real and imaginary parts stored in the last axis -> magnitude.
            img = np.sqrt(x[sl_dims[0], sl_dims[1], :, :, 0] ** 2 +
                          x[sl_dims[0], sl_dims[1], :, :, 1] ** 2)
    else:
        img = magnitude(x).cpu().data.numpy()[sl_dims[0], sl_dims[1], :, :]

    plt.imshow(img, cmap='gray', vmin=rng[0], vmax=rng[1])
    plt.show()

def ntensorshow(x, sl_dims=(0, 0), rng=(0, 0.001), titles=None, saveFigs=False, figname=None):
    """Plot the same slice of several image stacks side by side, one figure per slice.

    Args:
        x: sequence of tensors, each indexed as ``img[slice, channel, :, :]``. The
           number of figures is the leading dimension of ``x[0]``.
        sl_dims: only ``sl_dims[1]`` is used, as the channel index to display.
        rng: accepted for API compatibility but not used. The colour window is
             fixed to 0..2400.
        titles: optional per-panel titles, one per entry of ``x``.
        saveFigs: if True, write ``tensorboard_dir + figname[k] + '.png'`` and close
                  the figure; otherwise show it.
        figname: one filename stem per slice; defaults to '000', '001', ...
    """
    n_slices = x[0].shape[0]
    if figname is None:
        figname = [str(k).zfill(3) for k in range(n_slices)]

    for sl in range(n_slices):
        # squeeze=False keeps `axs` 2-D, so a single panel is still iterable.
        fig, axs = plt.subplots(1, len(x), squeeze=False)
        for panel, (ax, img) in enumerate(zip(axs[0], x)):
            if titles is not None:
                ax.set_title(titles[panel])

            if img.shape[1] > 1:
                # Multi-coil input: combine coils, then index the slice/channel once below
                # (previously the [sl, channel] indexing was applied twice, which fails).
                # NOTE(review): assumes combine_coils_RSOS keeps the leading
                # (slice, channel) layout — confirm.
                img = combine_coils_RSOS(img)
            if img.shape[-1] == 2:
                img = magnitude(img)
            img = img[sl, sl_dims[1], :, :]
            ax.imshow(img.cpu().data.numpy(), cmap='jet', vmin=0, vmax=2400)
            ax.axis('off')

        if saveFigs:
            fig.savefig(tensorboard_dir + figname[sl] + '.png', dpi=300)
            plt.close(fig)
        else:
            fig.show()
            fig.canvas.flush_events()


try:
    train(net)

except KeyboardInterrupt:
    print('Interrupted')
    # Keep whatever was learned so far.
    torch.save(net.state_dict(), 'MODEL_INTERRUPTED.pth')
    # Terminate immediately with os._exit. The previous `sys.exit(0)` was always
    # caught by its own `except SystemExit`, so os._exit was the only path taken.
    # os._exit skips interpreter cleanup, including flushing stdio buffers, so flush
    # explicitly; otherwise the message above can be lost when stdout is redirected.
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(0)